{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 混淆矩阵\n",
    "#### 适用于不平衡的数据分类问题\n",
    "行代表真实值, 列代表预测值\n",
    "\n",
    "|           |    0 | 1  |\n",
    "| :-------- | --------:| :--: |\n",
    "| 0  | TN | FP  |\n",
    "| 1  | FN |  TP |\n",
    "\n",
    "1 一般是我们关注的事件"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 精准率 $precision = \\frac{TP}{TP+FP}$ 预测10次,这10次正确的概率\n",
    "### 召回率 $recall = \\frac{TP}{TP+FN}$ 关注的事件发生了, 我们成功预测了多少"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 实现混淆矩阵\\精准率\\召回率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "digits = datasets.load_digits()\n",
    "X = digits.data\n",
    "y = digits.target.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1617,)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 手动构造偏斜skew数据\n",
    "y[digits.target==9] = 1\n",
    "y[digits.target!=9] = 0\n",
    "y[y!=1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X_train,X_test, y_train,y_test = train_test_split(X,y, random_state=666)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9755555555555555"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "log_reg = LogisticRegression()\n",
    "log_reg.fit(X_train, y_train)\n",
    "log_reg.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_log_predict = log_reg.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "403"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def TN(y_true, y_predict):\n",
    "    assert len(y_true) == len(y_predict)\n",
    "    return np.sum((y_true==0) & (y_predict==0))\n",
    "\n",
    "TN(y_test, y_log_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def FP(y_true, y_predict):\n",
    "    assert len(y_true) == len(y_predict)\n",
    "    return np.sum((y_true==0) & (y_predict==1))\n",
    "\n",
    "FP(y_test, y_log_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def FN(y_true, y_predict):\n",
    "    assert len(y_true) == len(y_predict)\n",
    "    return np.sum((y_true==1) & (y_predict==0))\n",
    "\n",
    "FN(y_test, y_log_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "36"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def TP(y_true, y_predict):\n",
    "    assert len(y_true) == len(y_predict)\n",
    "    return np.sum((y_true==1) & (y_predict==1))\n",
    "\n",
    "TP(y_test, y_log_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[403,   2],\n",
       "       [  9,  36]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def confusion_matrix(y_true, y_predict):\n",
    "    return np.array([[TN(y_true, y_predict),FP(y_true, y_predict)],\n",
    "                    [FN(y_true, y_predict),TP(y_true, y_predict)]])\n",
    "\n",
    "confusion_matrix(y_test, y_log_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9473684210526315"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def precision_score(y_true, y_predict):\n",
    "    tp = TP(y_true, y_predict)\n",
    "    fp = FP(y_true, y_predict)\n",
    "    try:\n",
    "        return tp/(tp+fp)\n",
    "    except:\n",
    "        return 0.0\n",
    "\n",
    "precision_score(y_test, y_log_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def recall_score(y_true, y_predict):\n",
    "    tp = TP(y_true, y_predict)\n",
    "    fn = FN(y_true, y_predict)\n",
    "    try:\n",
    "        return tp/(tp+fn)\n",
    "    except:\n",
    "        return 0.0\n",
    "\n",
    "recall_score(y_test, y_log_predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scikit-learn中的混淆矩阵, 精准率, 召回率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[403,   2],\n",
       "       [  9,  36]], dtype=int64)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "confusion_matrix(y_test, y_log_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9473684210526315"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import precision_score\n",
    "\n",
    "precision_score(y_test, y_log_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import recall_score\n",
    "\n",
    "recall_score(y_test, y_log_predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# F1 Score兼顾精准率和召回率\n",
    "F1 Score是precision和recall的调和平均值\n",
    "$\\frac{1}{F1}=\\frac{1}{2}(\\frac{1}{precision}+\\frac{1}{recall})$\n",
    "## $F1 = \\frac{2\\cdot precision\\cdot recall}{precision+recall}$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def f1_score(precision, recall):\n",
    "    try:\n",
    "        return 2*precision*recall/(precision+recall)\n",
    "    except:\n",
    "        return 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision =0.5\n",
    "recall = 0.5\n",
    "f1_score(precision, recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.18000000000000002"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision =0.1\n",
    "recall = 0.9\n",
    "f1_score(precision, recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn import datasets\n",
    "\n",
    "digits = datasets.load_digits()\n",
    "X = digits.data\n",
    "y = digits.target.copy()\n",
    "\n",
    "y[digits.target==9] = 1\n",
    "y[digits.target!=9] = 0\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X_train,X_test, y_train,y_test = train_test_split(X,y, random_state=666)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9755555555555555"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "log_reg = LogisticRegression()\n",
    "log_reg.fit(X_train, y_train)\n",
    "log_reg.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_predict = log_reg.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[403,   2],\n",
       "       [  9,  36]], dtype=int64)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "confusion_matrix(y_test, y_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9473684210526315\n",
      "0.8\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import precision_score,recall_score\n",
    "print(precision_score(y_test, y_predict))\n",
    "print(recall_score(y_test, y_predict))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8674698795180723"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import f1_score\n",
    "\n",
    "f1_score(y_test, y_predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# decision_function\n",
    "修改决策边界的 threshold 改变precision和recall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-22.05698737, -33.02937619, -16.21332482, -80.37914497,\n",
       "       -48.25127218, -24.540052  , -44.3917185 , -25.04291075,\n",
       "        -0.9782965 , -19.71744559, -66.25140542, -51.095976  ,\n",
       "       -31.49346952, -46.05338301, -38.67870918, -29.80471224,\n",
       "       -37.58846002, -82.57567291, -37.81904399, -11.01163509,\n",
       "        -9.17440656, -85.13002979, -16.71617095, -46.23719399,\n",
       "        -5.32989793, -47.91764215, -11.66729562, -39.1959603 ,\n",
       "       -25.25292315, -14.36646664, -16.99782425, -28.91903142,\n",
       "       -34.33938565, -29.47599298,  -7.85811474,  -3.82096582,\n",
       "       -24.0815564 , -22.16363196, -33.61212604, -23.14019304,\n",
       "       -26.91801894, -62.38934517, -38.85684738, -66.77256306,\n",
       "       -20.14479378, -17.47885338, -18.06798544, -22.22223522,\n",
       "       -29.6230231 , -19.73173031,   1.49553174,   8.3207997 ,\n",
       "       -36.29301444, -42.50729754, -25.90460604, -34.9895844 ,\n",
       "        -8.42008139, -50.04726697, -51.48206437,  19.88957006,\n",
       "        -8.91885763, -31.99340597, -11.66095184,  -0.47142072,\n",
       "       -49.16129363, -46.23795979, -25.05391525, -19.61345705,\n",
       "       -36.16659621,  -3.12535997,  -3.91419491, -19.06041993,\n",
       "       -21.03311413, -41.52237703, -12.00624474, -33.89268786,\n",
       "       -35.8480038 , -30.60470043, -56.5164008 , -18.45470062,\n",
       "         4.51534687, -17.21606409, -76.65095395, -58.54518742,\n",
       "       -31.72085465, -29.9082927 , -33.31894472,  -9.08748208,\n",
       "       -47.64451523, -66.15300449, -16.95627019, -22.2490188 ,\n",
       "       -11.48956383, -18.10557885, -68.6539921 , -47.02575541,\n",
       "       -40.1187028 , -35.50211379, -17.19764932, -63.10278929,\n",
       "       -16.95440906, -55.10244043, -28.71254946, -68.81579964,\n",
       "       -68.31013818,  -6.2593413 , -25.84000562, -38.0088188 ,\n",
       "       -27.90913213, -15.44709197, -27.45894688, -19.59776357,\n",
       "        12.33461573, -23.03864745, -35.94459717, -30.02828245,\n",
       "       -70.06673561, -29.48722884, -52.9882261 , -24.97013377,\n",
       "       -12.32841017, -48.00988436,  -2.49962509, -59.92447003,\n",
       "       -31.18113001,  -8.65727553, -71.34896157, -57.01113408,\n",
       "       -21.09870155, -21.53854878, -69.34307327, -18.63516916,\n",
       "       -39.91423632, -57.26578736,  -0.84506505, -21.88377038,\n",
       "       -22.64111199, -29.21262004, -35.15691288, -20.25850058,\n",
       "       -11.40289306,   3.87280538,   6.09028571,   1.4289726 ,\n",
       "        -7.82703794, -39.35174406,  12.2105426 , -75.10175181,\n",
       "       -75.38162387, -50.4180661 , -11.55438639, -48.45866001,\n",
       "       -75.44071708, -29.98056782, -64.11581665,  -7.16586305,\n",
       "        -6.52450885, -18.9725755 , -33.71610844, -17.76218045,\n",
       "       -45.59374746, -33.53729863, -34.0868588 , -73.31509876,\n",
       "       -15.43455101,  12.16746764, -56.4592371 ,  -6.03197163,\n",
       "       -49.08439661, -16.54210007,  -2.05950744, -11.81037135,\n",
       "       -33.47401333, -50.77178197, -10.62900394, -17.67499343,\n",
       "        -5.0782291 , -25.25777961, -16.61514356,   3.91126433,\n",
       "       -46.75595147, -12.89879509, -25.7478947 , -16.31799546,\n",
       "       -23.55093093, -83.48235908,  -6.23507848, -19.83968371,\n",
       "       -20.06231572, -26.65461837, -27.11268575, -39.63719148,\n",
       "       -39.81292761, -27.43656856, -24.11825897, -21.24519524,\n",
       "       -10.49821789, -19.39895604, -41.9575907 , -43.62363899,\n",
       "       -16.06837178, -64.096118  , -24.7545985 , -56.57384448,\n",
       "       -13.50013791, -30.01572626,   3.93719362, -44.716989  ,\n",
       "        -8.6936455 ,   1.58877927,  -2.76244465, -11.91893716,\n",
       "         7.5878903 ,  -7.25882682, -46.73815536, -49.19660016,\n",
       "        -4.80422087, -19.61029902, -24.30536823, -48.98792045,\n",
       "       -14.98128519, -24.83600473, -16.93953756, -19.46778493,\n",
       "       -15.77204315, -17.00121886, -39.23691441, -31.37453355,\n",
       "        -9.42196008, -71.38157137, -22.17500148, -14.72983366,\n",
       "       -23.57983424, -34.49383369,  -1.17651479, -32.90820356,\n",
       "       -10.82271498, -18.2622894 ,  -8.29308831, -44.84196332,\n",
       "       -22.5924786 , -61.73630119, -47.12968511, -65.62582626,\n",
       "       -33.36434923, -24.00481749, -29.33164745, -65.22704501,\n",
       "         1.43987441,  -4.56085463, -25.25846467, -22.46482704,\n",
       "       -54.43071124, -16.81737672, -11.28768458, -35.25837399,\n",
       "        -5.57317444, -14.93089605, -70.95372825,  -6.50505851,\n",
       "        -1.22953268, -37.87545059, -23.68942044, -68.29961815,\n",
       "        14.93804184, -62.55687196,  10.14794474, -24.44796256,\n",
       "       -32.85380848, -14.32956257, -85.68605241, -13.16395829,\n",
       "         9.27784964, -17.32719424, -36.06510105, -17.0471784 ,\n",
       "       -19.7131253 , -32.72634943,  -5.36343681,   7.68322704,\n",
       "         9.20405169,   5.76533231, -35.96345868, -13.02388752,\n",
       "       -54.87487205, -41.61763047,   5.93738017, -79.11920794,\n",
       "       -16.01401254, -19.72189599, -10.96333437, -42.55206978,\n",
       "       -19.70957086, -16.20502082, -18.68732174, -17.94402868,\n",
       "        -7.17459758, -20.54727195, -16.81068574, -70.69029118,\n",
       "        -9.817761  , -32.8703642 , -18.97769751, -21.37930925,\n",
       "       -25.15047454, -17.10978682, -13.5237094 , -23.76118253,\n",
       "        11.36508032, -14.50017129, -33.86299099, -13.71702129,\n",
       "       -50.52171739, -20.26632313, -56.12696072, -29.24268867,\n",
       "       -22.10082549, -31.3932142 , -68.99340534, -60.34419934,\n",
       "        14.35287383,   8.69508987, -25.31399645,   2.38295671,\n",
       "         5.04573313, -19.56492305, -59.19921571, -10.05787895,\n",
       "       -29.66208016, -27.40192476,   6.13016734, -80.46964618,\n",
       "       -34.87537927, -49.84646393, -36.03963638, -48.50250495,\n",
       "       -19.96806322, -62.05772244,  -3.23793339, -25.3290881 ,\n",
       "       -65.14032918,  -9.4273078 , -23.31742945,  19.3862689 ,\n",
       "       -18.84543883,  -4.47308053, -13.77207332, -21.88091197,\n",
       "       -43.41397204, -51.85062785, -28.83912826, -13.90476288,\n",
       "        -2.51952947,  -6.1601764 ,   3.14870387, -15.3399438 ,\n",
       "       -41.16626412, -25.89745989,  14.3019736 , -17.88815371,\n",
       "        14.67463162, -33.65791417,   4.8244983 , -14.42660021,\n",
       "       -54.22944449, -50.49124644, -30.54681992, -38.72563279,\n",
       "       -23.46176163, -24.87716619, -14.50554596, -23.72451893,\n",
       "       -28.07007877, -19.63714298, -28.66181355, -20.37688578,\n",
       "       -32.16743375, -11.15573432, -17.95922281, -24.54353174,\n",
       "       -24.60831453,  10.73692592, -16.68581903, -38.50772697,\n",
       "       -15.87674066, -37.05234203, -15.79366978, -68.69482285,\n",
       "       -33.64808174, -43.60838876, -28.74746984,  -9.889897  ,\n",
       "       -67.16453477, -33.49880963, -45.89913449, -14.36736647,\n",
       "       -38.28987967, -14.76245782, -70.44232899, -11.19628972,\n",
       "       -41.46527867, -32.38989412, -20.86064198, -27.68981429,\n",
       "       -16.06078182, -31.9631644 ,  -8.48419813, -22.10449658,\n",
       "       -34.06026334, -12.47049546, -36.15120085, -36.57954185,\n",
       "       -22.46157054,   4.4753967 , -20.80765471,  -3.75031509,\n",
       "       -20.31646516, -32.67824182, -41.10705975, -25.46015961,\n",
       "       -19.73666755, -47.83297157, -29.85781029, -45.24584906,\n",
       "       -71.65700018,  -5.93564181, -32.93705837,   1.89668235,\n",
       "        11.76385817,   7.35782315, -30.93185528, -63.94237748,\n",
       "       -23.41433687,  -5.43420909, -33.46407064, -24.11268205,\n",
       "       -67.49717861, -34.30051742, -34.23317026, -31.61584231,\n",
       "       -52.86791251, -22.89217794,  -8.16021531, -17.73972076,\n",
       "       -26.98681352, -32.38759687, -28.96085065, -67.25180475,\n",
       "       -46.49551892, -16.11283724])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decision_scores = log_reg.decision_function(X_test)\n",
    "decision_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "19.889570056817906"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decision_scores.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-85.68605241251701"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decision_scores.min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 这里修改 decision_scores 的阈值, 改变精准率和召回率\n",
    "y_predict_2 = np.array(decision_scores >= 5, dtype='int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[404,   1],\n",
       "       [ 21,  24]], dtype=int64)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confusion_matrix(y_test, y_predict_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.96"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_score(y_test, y_predict_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5333333333333333"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "recall_score(y_test, y_predict_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6857142857142858"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f1_score(y_test, y_predict_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
